import os
import sys
from pathlib import Path
import ast
import re
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import statistics
import math
import seaborn as sns
import pandas as pd


def atoi(text):
    """Convert ``text`` to an ``int`` when it is made up only of digits.

    Any other string (including the empty string) is returned unchanged, so
    the result can be used as one component of a natural-sort key.
    """
    if text.isdigit():
        return int(text)
    return text


def natural_keys(text):
    """Build a sort key that orders strings the way a human would.

    ``text`` is split into alternating non-digit and digit chunks, and every
    digit chunk is turned into an ``int`` by ``atoi`` so that, for example,
    ``"task10"`` sorts after ``"task2"``.

    Usage: ``alist.sort(key=natural_keys)``.
    Based on http://nedbatchelder.com/blog/200712/human_sorting.html
    (see Toothy's implementation in the comments).
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))


def get_results(directory):
    """Collect every logged run found below ``directory``.

    The directory tree is searched recursively for files named ``logs.pyd``
    that sit inside at least one sub-directory. For each file, the first path
    component equal to ``'class-il'`` or ``'task-il'`` is taken as the
    setting; the component after it is the dataset and the one after that is
    the method. Every non-blank line of the file is parsed with
    ``ast.literal_eval`` and appended as one run.

    Args:
        directory: Root directory to search.

    Returns:
        A nested dict ``results[method][dataset][setting]`` holding the list
        of parsed runs, in file order.

    Raises:
        ValueError, SyntaxError: If a non-blank line is not a valid Python
            literal (propagated from ``ast.literal_eval``).
    """
    results = {}
    for path in Path(directory).rglob('*/logs.pyd'):
        parts = path.as_posix().split("/")
        for idx, part in enumerate(parts):
            if part not in ('class-il', 'task-il'):
                continue
            if idx + 2 >= len(parts):
                # The setting directory is not followed by both a dataset and
                # a method component; skip the file instead of raising
                # IndexError.
                break
            dataset = parts[idx + 1]
            method = parts[idx + 2]
            runs = (results.setdefault(method, {})
                           .setdefault(dataset, {})
                           .setdefault(part, []))
            with open(path, "r") as data_file:
                for line in data_file:
                    # Blank lines (e.g. a trailing newline) carry no run and
                    # would make literal_eval raise SyntaxError.
                    if not line.strip():
                        continue
                    runs.append(ast.literal_eval(line))
            # Only the first matching component defines the setting.
            break
    return results


def get_setting(run):
    """Extract the hyperparameter setting of a single run.

    The returned dict always holds ``n_epochs``, ``lr``, ``optim_mom`` and
    ``optim_wd``. Runs whose ``model`` is ``'derpp'`` additionally contribute
    ``alpha`` and ``beta``; runs whose ``model`` is ``'esmer'`` contribute
    ``loss_margin``.

    Raises:
        KeyError: If ``run`` lacks one of the keys that are read.
    """
    common_keys = ('n_epochs', 'lr', 'optim_mom', 'optim_wd')
    setting = {key: run[key] for key in common_keys}

    model = run['model']
    if model == 'derpp':
        model_keys = ('alpha', 'beta')
    elif model == 'esmer':
        model_keys = ('loss_margin',)
    else:
        model_keys = ()

    for key in model_keys:
        setting[key] = run[key]
    return setting


def get_val_cv_hpo_acc(selected_hp, hpo_avg_perf_stats):
    """Look up the accuracy recorded for ``selected_hp``.

    ``hpo_avg_perf_stats`` is iterated as ``(setting, acc)`` pairs; the
    accuracy of the first pair whose setting equals ``selected_hp`` is
    returned, or ``None`` when no pair matches.
    """
    matching_accs = (
        acc for setting, acc in hpo_avg_perf_stats if setting == selected_hp
    )
    return next(matching_accs, None)


def get_per_task_accs(run):
    """Return the values of all ``accmean_task`` entries of ``run``.

    The keys of ``run`` are ordered with ``natural_keys`` (so task 10 comes
    after task 2) and the values of those containing ``"accmean_task"`` are
    returned in that order.
    """
    ordered_keys = sorted(run, key=natural_keys)
    return [run[key] for key in ordered_keys if "accmean_task" in key]


def get_eot_acc(run):
    """Return the ``accmean`` value recorded for the highest task number.

    Every key containing ``'accmean'`` is considered; its task number is the
    integer that follows the first ``"k"`` in the key. Only task numbers
    greater than 0 are taken into account, and ``None`` is returned when no
    such key exists.
    """
    latest_task = 0
    latest_acc = None
    for name, value in run.items():
        if 'accmean' not in name:
            continue
        task_number = int(name.split("k")[1])
        if task_number > latest_task:
            latest_task, latest_acc = task_number, value
    return latest_acc


def get_best_eot_val_settings(results):
    """Pick, per method/dataset/setting, the hyperparameters with the best accuracy.

    Only runs whose ``'validation'`` entry is non-zero are considered. Each is
    scored with ``get_eot_acc``; on ties the later run wins. The chosen
    setting (or ``None`` when no run qualifies) is printed and stored.

    Returns:
        A nested dict ``best[method][dataset][setting]`` holding the selected
        hyperparameter dict as produced by ``get_setting``.
    """
    best_hpo_settings = {}
    for method, per_dataset in results.items():
        method_best = best_hpo_settings.setdefault(method, {})
        for dataset, per_setting in per_dataset.items():
            dataset_best = method_best.setdefault(dataset, {})
            for setting, runs in per_setting.items():
                top_acc = -1
                top_hps = None
                for run in runs:
                    if run['validation'] == 0:
                        continue
                    candidate = get_setting(run)
                    acc = get_eot_acc(run)
                    if acc >= top_acc:
                        top_hps, top_acc = candidate, acc
                dataset_best[setting] = top_hps
                print("For "+method+" "+dataset+" "+setting+": "+str(top_hps))
    return best_hpo_settings


def pretty_type_dataset_names(dataset_name):
    """Map an internal dataset identifier to its display name.

    Raises:
        KeyError: If ``dataset_name`` is not one of the known identifiers.
    """
    display_names = {
        "seq-cifar100": "CIFAR-100",
        "seq-cifar10": "CIFAR-10",
        "seq-tinyimg": "Tiny ImageNet",
        "hetro-cifar100": "Hetero CIFAR-100",
        "hetro-tinyimg": "Hetero Tiny ImageNet",
    }
    return display_names[dataset_name]


def pretty_type_method_names(method_name):
    """Map an internal method identifier to its display name.

    Raises:
        KeyError: If ``method_name`` is not one of the known identifiers.
    """
    display_names = {
        "derpp": "DER++",
        "er": "ER",
        "er_ace": "ER-ACE",
        "icarl": "iCaRL",
        "esmer": "ESMER",
        "sgd": "SGD",
    }
    return display_names[method_name]


def plot_hists(results, setting):
    """Show one accuracy histogram per method/dataset for ``setting``.

    For every method and dataset in ``results`` that contains ``setting``,
    the ``get_eot_acc`` value of each run whose ``'validation'`` entry is
    non-zero is collected, plotted with ``sns.histplot`` and displayed with
    ``plt.show()``. The plot title is the pretty dataset name; font sizes come
    from the module-level ``font_size``.

    Args:
        results: Nested dict ``results[method][dataset][setting]`` -> list of
            run dicts.
        setting: Key of the setting to plot (the commented-out call at the
            bottom of the file passes ``"class-il"``).

    Returns:
        A nested dict ``{method: {dataset: {}}}``; the innermost dicts are
        created but never filled by this function.
    """
    #plt.style.use('seaborn')
    # NOTE(review): set_theme() is called right after set(font_scale=6);
    # confirm that the font scale requested here is still in effect afterwards.
    sns.set(font_scale=6)
    sns.set_theme(style="ticks")

    best_hpo_settings = {}
    #plt.close()
    for method in results:
        if method not in best_hpo_settings:
            best_hpo_settings[method] = {}
        for dataset in results[method]:
            # Axes styling is applied for every dataset, before the histogram
            # is drawn and even when ``setting`` is absent for this dataset.
            sns.despine()
            plt.tick_params(labelsize=20)
            plt.grid(alpha=0.3)
            if dataset not in best_hpo_settings[method]:
                best_hpo_settings[method][dataset] = {}
            if setting in results[method][dataset]:
                print("For "+method+" "+dataset+" "+setting)
                # Accuracies of the validation runs only.
                accs = []
                for run in results[method][dataset][setting]:
                    if run['validation'] != 0:
                        accs.append(get_eot_acc(run))
                # The 60-bin histogram is only used to derive a bin width;
                # ``hist`` and ``center`` are referenced solely by the
                # commented-out plt.bar call below.
                hist, bins = np.histogram(accs, bins=60)
                width = 0.7 * (bins[1] - bins[0])
                center = (bins[:-1] + bins[1:]) / 2
                #pd_accs = pd.DataFrame({"Average Accuracy": accs})
                #plt.bar(center, hist, align='center', width=width)
                # binwidth is 1.4 times the width of one of the 60 bins above.
                sns.histplot(accs, binwidth=width*2)
                plt.title(pretty_type_dataset_names(dataset), fontsize=font_size)
                plt.xlabel("Average Accuracy", fontsize=font_size)
                plt.ylabel("Frequency", fontsize=font_size)
                plt.tight_layout()
                plt.show()

    return best_hpo_settings


def get_best_per_task_settings_with_acc(results, setting):
    """Pair each per-task selected hyperparameter setting with its accuracies.

    For every method/dataset, the first run stored under ``setting`` is
    inspected. Datasets whose first run has no ``'selected_hp'`` entry are
    skipped. Otherwise, for each index ``i`` the validation accuracy of
    ``selected_hp[i]`` is looked up in ``hpo_avg_perf_stats[i]`` and scaled
    by 100.

    Returns:
        A nested dict ``best[method][dataset]`` holding a list of
        ``(selected_hp, val_acc, class_il_test_acc, task_il_test_acc)``
        tuples. The class-il accuracies are read from the first ``'class-il'``
        run; the task-il ones from the first run of ``setting``.
    """
    best_hpo_settings = {}
    for method, per_dataset in results.items():
        method_entry = best_hpo_settings.setdefault(method, {})
        for dataset, per_setting in per_dataset.items():
            first_run = per_setting[setting][0]
            if 'selected_hp' not in first_run:
                continue

            selected_hps = first_run['selected_hp']
            raw_val_accs = [
                get_val_cv_hpo_acc(selected_hps[idx], first_run['hpo_avg_perf_stats'][idx])
                for idx in range(len(selected_hps))
            ]
            val_accs = 100*np.array(raw_val_accs)

            task_il_test_accs = get_per_task_accs(first_run)
            class_il_test_accs = get_per_task_accs(per_setting['class-il'][0])
            method_entry[dataset] = list(
                zip(selected_hps, val_accs, class_il_test_accs, task_il_test_accs)
            )

    return best_hpo_settings


def get_mean_and_std_err_for_per_task_accs(runs):
    """Aggregate per-task accuracies across ``runs``.

    The per-task accuracy lists of all runs are transposed so that each task
    is summarised over the runs (truncated to the shortest list).

    Returns:
        One ``(mean, standard_error)`` tuple per task, where the standard
        error is the sample standard deviation divided by ``sqrt(n_runs)``.

    Raises:
        statistics.StatisticsError: If a task has fewer than two values,
            since ``statistics.stdev`` needs at least two data points.
    """
    def _mean_and_std_err(accs):
        return (statistics.mean(accs), statistics.stdev(accs)/math.sqrt(len(accs)))

    accs_by_task = zip(*[get_per_task_accs(run) for run in runs])
    return [_mean_and_std_err(accs) for accs in accs_by_task]


def _summarise_per_task_accs(runs):
    """Return raw per-task accuracies for one run, else (mean, std err) per task."""
    if len(runs) == 1:
        return get_per_task_accs(runs[0])
    return get_mean_and_std_err_for_per_task_accs(runs)


def get_best_results(results):
    """Combine selected hyperparameters with class-il and task-il test accuracies.

    For every method/dataset the ``'task-il'`` and ``'class-il'`` runs are
    each summarised per task: a single run yields its raw accuracies, several
    runs yield ``(mean, standard_error)`` tuples. Each setting is summarised
    according to its *own* number of runs, so a differing run count between
    the two settings neither drops runs nor requests a standard deviation of
    a single value.

    The hyperparameters are taken from the first ``'class-il'`` run: its
    ``'selected_hp'`` list when present, otherwise its ``get_setting`` result
    repeated once per task.

    Args:
        results: Nested dict ``results[method][dataset][setting]`` -> list of
            run dicts; both ``'task-il'`` and ``'class-il'`` must be present.

    Returns:
        A nested dict ``best[method][dataset]`` holding a list of
        ``(hyperparameters, class_il_acc, task_il_acc)`` tuples, one per task.

    Raises:
        KeyError: If a dataset lacks the ``'task-il'`` or ``'class-il'`` key.
        IndexError: If a dataset has no ``'class-il'`` runs.
    """
    best_hpo_settings = {}
    for method, per_dataset in results.items():
        method_entry = best_hpo_settings.setdefault(method, {})
        for dataset, per_setting in per_dataset.items():
            task_il_test_accs = _summarise_per_task_accs(per_setting['task-il'])
            class_il_runs = per_setting['class-il']
            class_il_test_accs = _summarise_per_task_accs(class_il_runs)

            first_class_il_run = class_il_runs[0]
            if 'selected_hp' in first_class_il_run:
                selected_hps = first_class_il_run['selected_hp']
            else:
                # No per-task selection was logged: the run's own setting
                # applies to every task.
                selected_hps = [get_setting(first_class_il_run)]*len(class_il_test_accs)

            method_entry[dataset] = list(
                zip(selected_hps, class_il_test_accs, task_il_test_accs)
            )

    return best_hpo_settings

# Font type 42 makes the PDF/PS backends embed TrueType fonts.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# Shared font size; also read by plot_hists for titles and axis labels.
font_size = 22
matplotlib.rcParams.update({'font.size': font_size})

if __name__ == "__main__":
    # The results directory is the only command-line argument.
    if len(sys.argv) < 2:
        sys.exit("usage: " + sys.argv[0] + " <results_directory>")
    results = get_results(directory=sys.argv[1])
    #get_best_eot_val_settings(results)
    out = get_best_results(results)
    for method in out:
        for dataset in out[method]:
            print(str(method)+" "+str(dataset))
            for i, (hps, class_il_acc, task_il_acc) in enumerate(out[method][dataset]):
                # get_best_results does not return validation accuracies, so
                # no val_acc field is printed.
                print("task " + str(i) + ": " + "setting = " + str(hps) + ", class_il_acc = " + str(class_il_acc) + ", task_il_acc = " + str(task_il_acc))
    #plot_hists(results, "class-il")